{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Private Predictions with Syft Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 1: Public Training\n",
    "\n",
    "Welcome to this tutorial! In the following notebooks you will learn how to provide private predictions. By private predictions, we mean that the data is constantly encrypted throughout the entire process. At no point is the user sharing raw data, only encrypted (that is, secret shared) data. In order to provide these private predictions, Syft Keras uses a library called [TF Encrypted](https://github.com/tf-encrypted/tf-encrypted) under the hood. TF Encrypted combines cutting-edge cryptographic and machine learning techniques, but you don't have to worry about this and can focus on your machine learning application.\n",
    "\n",
    "You can start serving private predictions with only three steps:\n",
    "- **Step 1**: train your model with normal Keras.\n",
    "- **Step 2**: secure and serve your machine learning model (server).\n",
    "- **Step 3**: query the secured model to receive private predictions (client). \n",
    "\n",
    "Alright, let's go through these three steps so you can deploy impactful machine learning services without sacrificing user privacy or model security.\n",
    "\n",
    "Authors:\n",
    "- Jason Mancuso - Twitter: [@jvmancuso](https://twitter.com/jvmancuso)\n",
    "- Yann Dupis - Twitter: [@YannDupis](https://twitter.com/YannDupis)\n",
    "- Morten Dahl - Twitter: [@mortendahlcs](https://github.com/mortendahlcs)\n",
    "\n",
    "On behalf of:\n",
    "- Dropout Labs - Twitter: [@dropoutlabs](https://twitter.com/dropoutlabs)\n",
    "- TF Encrypted - Twitter: [@tf_encrypted](https://twitter.com/tf_encrypted)"
   ]
  },
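  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The introduction describes user data as secret shared. To give a rough intuition for what that means, the cell below is a minimal sketch of additive secret sharing in plain Python. It is only an illustration and not TF Encrypted's implementation: the names `MODULUS`, `share`, and `reconstruct`, the modulus value, and the choice of three shares are all assumptions made for this example. Each share on its own is a uniformly random number, and only the sum of all shares reveals the original value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secrets\n",
    "\n",
    "# Hypothetical modulus chosen for illustration only.\n",
    "MODULUS = 2**31 - 1\n",
    "\n",
    "def share(secret, num_shares=3):\n",
    "    # Draw all but one share at random, then pick the last one\n",
    "    # so that the shares sum to the secret modulo MODULUS.\n",
    "    shares = [secrets.randbelow(MODULUS) for _ in range(num_shares - 1)]\n",
    "    shares.append((secret - sum(shares)) % MODULUS)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # Adding every share modulo MODULUS recovers the secret.\n",
    "    return sum(shares) % MODULUS\n",
    "\n",
    "pixel = 137  # an example raw value a user would not want to reveal\n",
    "pixel_shares = share(pixel)\n",
    "print('shares:', pixel_shares)\n",
    "print('reconstructed:', reconstruct(pixel_shares))"
   ]
  },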
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your model with Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use privacy-preserving machine learning techniques for your projects you should not have to learn a new machine learning framework. If you have basic [Keras](https://keras.io/) knowledge, you can start using these techniques with Syft Keras. If you have never used Keras before, you can learn a bit more about it through the [Keras documentation](https://keras.io). \n",
    "\n",
    "Before serving private predictions, the first step is to train your model with normal Keras. As an example, we will train a model to classify handwritten digits. To train this model we will use the canonical [MNIST dataset](http://yann.lecun.com/exdb/mnist/).\n",
    "\n",
    "We borrow [this example](https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py) from the reference Keras repository.  To train your classification model, you just run the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import tensorflow.keras as keras\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Dropout, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, AveragePooling2D\n",
    "from tensorflow.keras.layers import Activation\n",
    "\n",
    "batch_size = 128\n",
    "num_classes = 10\n",
    "epochs = 2\n",
    "\n",
    "# input image dimensions\n",
    "img_rows, img_cols = 28, 28\n",
    "\n",
    "# the data, split between train and test sets\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
    "x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
    "input_shape = (img_rows, img_cols, 1)\n",
    "\n",
    "x_train = x_train.astype('float32')\n",
    "x_test = x_test.astype('float32')\n",
    "x_train /= 255\n",
    "x_test /= 255\n",
    "print('x_train shape:', x_train.shape)\n",
    "print(x_train.shape[0], 'train samples')\n",
    "print(x_test.shape[0], 'test samples')\n",
    "\n",
    "# convert class vectors to binary class matrices\n",
    "y_train = keras.utils.to_categorical(y_train, num_classes)\n",
    "y_test = keras.utils.to_categorical(y_test, num_classes)\n",
    "\n",
    "model = Sequential()\n",
    "\n",
    "model.add(Conv2D(10, (3, 3), input_shape=input_shape))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Conv2D(32, (3, 3)))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Conv2D(64, (3, 3)))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(num_classes, activation='softmax'))\n",
    "\n",
    "model.compile(loss=keras.losses.categorical_crossentropy,\n",
    "              optimizer=keras.optimizers.Adam(),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.fit(x_train, y_train,\n",
    "          batch_size=batch_size,\n",
    "          epochs=epochs,\n",
    "          verbose=1,\n",
    "          validation_data=(x_test, y_test))\n",
    "score = model.evaluate(x_test, y_test, verbose=0)\n",
    "print('Test loss:', score[0])\n",
    "print('Test accuracy:', score[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save your model's weights for future private prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Your model is trained. Let's save the model weights with `model.save()`. In the next notebook, we will load these weights in Syft Keras to start serving private predictions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save('short-conv-mnist.h5')"
   ]
  },
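  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, the saved file can be loaded back and compared with the model still in memory. This is a minimal sketch that assumes the cells above have been run in this session, so that `model`, `x_test`, and `short-conv-mnist.h5` exist. The variable names `restored` and `sample`, the five test images, and the tolerance are arbitrary choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "# Load the file written by model.save() in the previous cell.\n",
    "restored = load_model('short-conv-mnist.h5')\n",
    "\n",
    "# Compare the predictions of both models on a few test images.\n",
    "sample = x_test[:5]\n",
    "same = np.allclose(model.predict(sample), restored.predict(sample), atol=1e-6)\n",
    "print('Predictions match:', same)"
   ]
  },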
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
